{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 30分钟搞懂它，加油！\n",
    "\n",
    "\n",
    "## 导入数据\n",
    "\n",
    "PyTorch 中有两个处理数据的工具：`torch.utils.data.DataLoader` 和 `torch.utils.data.Dataset`。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import datasets\n",
    "from torchvision.transforms import ToTensor, Lambda, Compose\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "# Fetch the open FashionMNIST dataset through torchvision's `datasets` module.\n",
    "# The two splits differ only in the `train` flag, so both are built from one\n",
    "# expression: True yields `training_data`, False yields `test_data`.\n",
    "training_data, test_data = (\n",
    "    datasets.FashionMNIST(\n",
    "        root=\"data\",\n",
    "        train=is_train_split,\n",
    "        download=True,\n",
    "        transform=ToTensor(),\n",
    "    )\n",
    "    for is_train_split in (True, False)\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们将 `Dataset` 作为参数传递给 `DataLoader`。`DataLoader` 在数据集上包装了一个可迭代对象，并支持自动批处理（batching）、采样（sampling）、混洗（shuffling）和多进程数据加载（multiprocess data loading）。这里我们将批大小（batch size）定义为 64，即数据加载器每次迭代都会返回一批包含 64 个特征和标签的数据。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Shape of X [N, C, H, W]:  torch.Size([64, 1, 28, 28])\n",
      "Shape of y:  torch.Size([64]) torch.int64\n"
     ]
    }
   ],
   "source": [
    "batch_size = 64\n",
    "\n",
    "# Wrap each dataset in a DataLoader that yields batches of `batch_size` samples.\n",
    "train_dataloader = DataLoader(dataset=training_data, batch_size=batch_size)\n",
    "test_dataloader = DataLoader(dataset=test_data, batch_size=batch_size)\n",
    "\n",
    "# Peek at the first test batch only, to show the tensor shapes and the label dtype.\n",
    "X, y = next(iter(test_dataloader))\n",
    "print(\"Shape of X [N, C, H, W]: \", X.shape)\n",
    "print(\"Shape of y: \", y.shape, y.dtype)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "***"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 创建模型\n",
    "\n",
    "我们创建一个继承自 `nn.Module` 的类，以此在 PyTorch 中定义神经网络。在 `__init__` 函数中定义网络层，并在 `forward` 函数中指定数据将如何通过网络。如果 GPU 可用，我们可以将神经网络移至 GPU 来加速模型的训练。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using cuda device\n",
      "NeuralNetwork(\n",
      "  (flatten): Flatten(start_dim=1, end_dim=-1)\n",
      "  (linear_relu_stack): Sequential(\n",
      "    (0): Linear(in_features=784, out_features=512, bias=True)\n",
      "    (1): ReLU()\n",
      "    (2): Linear(in_features=512, out_features=512, bias=True)\n",
      "    (3): ReLU()\n",
      "    (4): Linear(in_features=512, out_features=10, bias=True)\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# 找到可以用于训练的 GPU\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "print(\"Using {} device\".format(device))\n",
    "\n",
    "# 定义模型\n",
    "class NeuralNetwork(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(NeuralNetwork, self).__init__()\n",
    "        self.flatten = nn.Flatten()\n",
    "        self.linear_relu_stack = nn.Sequential(\n",
    "            nn.Linear(28*28, 512),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(512, 512),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(512, 10)\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.flatten(x)\n",
    "        logits = self.linear_relu_stack(x)\n",
    "        return logits\n",
    "\n",
    "model = NeuralNetwork().to(device)\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "--------------\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 优化模型参数\n",
    "定义一个损失函数和一个优化器。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "# Cross-entropy loss on the model's outputs, minimised with SGD at lr = 1e-3.\n",
    "loss_fn = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.SGD(params=model.parameters(), lr=1e-3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在单个训练循环中，模型对训练数据集进行预测（分批提供给它），并反向传播预测误差从而调整模型的参数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "def train(dataloader, model, loss_fn, optimizer):\n",
    "    \"\"\"Run one pass over ``dataloader``, updating ``model`` after every batch.\n",
    "\n",
    "    Each batch is moved to the notebook-level ``device``, scored with\n",
    "    ``loss_fn`` and used for a single ``optimizer`` step. On every 100th\n",
    "    batch (starting with the first) the batch loss is printed together with\n",
    "    ``batch index * batch length`` and the dataset size.\n",
    "    \"\"\"\n",
    "    size = len(dataloader.dataset)\n",
    "    model.train()\n",
    "    for step, (features, labels) in enumerate(dataloader):\n",
    "        features = features.to(device)\n",
    "        labels = labels.to(device)\n",
    "\n",
    "        # Forward pass: prediction error for this batch.\n",
    "        loss = loss_fn(model(features), labels)\n",
    "\n",
    "        # Backward pass: clear stale gradients, backpropagate, update parameters.\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        if step % 100 != 0:\n",
    "            continue\n",
    "        # Progress report; from here on `loss` is a plain Python float.\n",
    "        loss, current = loss.item(), step * len(features)\n",
    "        print(f\"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "根据测试数据集检查模型的性能，以确保模型正在学习。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "def test(dataloader, model, loss_fn):\n",
    "    \"\"\"Evaluate ``model`` on ``dataloader`` and print accuracy and average loss.\n",
    "\n",
    "    Runs in eval mode with gradients disabled. The printed loss is the sum of\n",
    "    the per-batch losses divided by the number of batches; the accuracy is the\n",
    "    count of correct argmax predictions divided by the dataset size.\n",
    "    \"\"\"\n",
    "    size = len(dataloader.dataset)\n",
    "    num_batches = len(dataloader)\n",
    "    model.eval()\n",
    "    test_loss = 0\n",
    "    correct = 0\n",
    "    with torch.no_grad():\n",
    "        for inputs, targets in dataloader:\n",
    "            inputs = inputs.to(device)\n",
    "            targets = targets.to(device)\n",
    "            logits = model(inputs)\n",
    "            test_loss += loss_fn(logits, targets).item()\n",
    "            # A prediction is correct when the highest-scoring class equals the label.\n",
    "            hits = (logits.argmax(1) == targets).type(torch.float)\n",
    "            correct += hits.sum().item()\n",
    "    test_loss /= num_batches\n",
    "    correct /= size\n",
    "    print(f\"Test Error: \\n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "打印每一次迭代过程中的监控数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "scrolled": true,
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1\n",
      "-------------------------------\n",
      "loss: 2.308198  [    0/60000]\n",
      "loss: 2.303675  [ 6400/60000]\n",
      "loss: 2.279196  [12800/60000]\n",
      "loss: 2.273016  [19200/60000]\n",
      "loss: 2.256729  [25600/60000]\n",
      "loss: 2.224995  [32000/60000]\n",
      "loss: 2.228598  [38400/60000]\n",
      "loss: 2.191386  [44800/60000]\n",
      "loss: 2.198501  [51200/60000]\n",
      "loss: 2.156352  [57600/60000]\n",
      "Test Error: \n",
      " Accuracy: 48.6%, Avg loss: 2.157153 \n",
      "\n",
      "Epoch 2\n",
      "-------------------------------\n",
      "loss: 2.171853  [    0/60000]\n",
      "loss: 2.161133  [ 6400/60000]\n",
      "loss: 2.100347  [12800/60000]\n",
      "loss: 2.119125  [19200/60000]\n",
      "loss: 2.059523  [25600/60000]\n",
      "loss: 2.006031  [32000/60000]\n",
      "loss: 2.032665  [38400/60000]\n",
      "loss: 1.945392  [44800/60000]\n",
      "loss: 1.962137  [51200/60000]\n",
      "loss: 1.883276  [57600/60000]\n",
      "Test Error: \n",
      " Accuracy: 55.2%, Avg loss: 1.881818 \n",
      "\n",
      "Epoch 3\n",
      "-------------------------------\n",
      "loss: 1.920210  [    0/60000]\n",
      "loss: 1.885050  [ 6400/60000]\n",
      "loss: 1.766944  [12800/60000]\n",
      "loss: 1.817860  [19200/60000]\n",
      "loss: 1.691585  [25600/60000]\n",
      "loss: 1.653540  [32000/60000]\n",
      "loss: 1.682247  [38400/60000]\n",
      "loss: 1.572053  [44800/60000]\n",
      "loss: 1.601556  [51200/60000]\n",
      "loss: 1.506832  [57600/60000]\n",
      "Test Error: \n",
      " Accuracy: 59.6%, Avg loss: 1.520164 \n",
      "\n",
      "Epoch 4\n",
      "-------------------------------\n",
      "loss: 1.587611  [    0/60000]\n",
      "loss: 1.549974  [ 6400/60000]\n",
      "loss: 1.404029  [12800/60000]\n",
      "loss: 1.486433  [19200/60000]\n",
      "loss: 1.354578  [25600/60000]\n",
      "loss: 1.356298  [32000/60000]\n",
      "loss: 1.372628  [38400/60000]\n",
      "loss: 1.290354  [44800/60000]\n",
      "loss: 1.321386  [51200/60000]\n",
      "loss: 1.236207  [57600/60000]\n",
      "Test Error: \n",
      " Accuracy: 63.0%, Avg loss: 1.258429 \n",
      "\n",
      "Epoch 5\n",
      "-------------------------------\n",
      "loss: 1.335369  [    0/60000]\n",
      "loss: 1.314450  [ 6400/60000]\n",
      "loss: 1.152564  [12800/60000]\n",
      "loss: 1.264987  [19200/60000]\n",
      "loss: 1.132541  [25600/60000]\n",
      "loss: 1.157186  [32000/60000]\n",
      "loss: 1.176993  [38400/60000]\n",
      "loss: 1.112688  [44800/60000]\n",
      "loss: 1.145133  [51200/60000]\n",
      "loss: 1.074617  [57600/60000]\n",
      "Test Error: \n",
      " Accuracy: 64.5%, Avg loss: 1.092484 \n",
      "\n",
      "Done!\n"
     ]
    }
   ],
   "source": [
    "epochs = 5\n",
    "# Per epoch: one pass over the training set, then an evaluation on the test set.\n",
    "for epoch_number in range(1, epochs + 1):\n",
    "    print(f\"Epoch {epoch_number}\\n-------------------------------\")\n",
    "    train(train_dataloader, model, loss_fn, optimizer)\n",
    "    test(test_dataloader, model, loss_fn)\n",
    "print(\"Done!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "--------------\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 保存模型\n",
    "\n",
    "保存模型的常用方法是序列化内部状态字典（包含模型参数）。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "模型已保存成功！\n"
     ]
    }
   ],
   "source": [
    "# Persist only the state dict (the model's parameters), not the module object.\n",
    "torch.save(obj=model.state_dict(), f=\"model.pth\")\n",
    "print(\"模型已保存成功！\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 加载模型\n",
    "\n",
    "加载模型的过程包括重新创建模型结构并将状态字典加载到其中。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Rebuild the architecture (on the CPU), then restore the saved parameters.\n",
    "# map_location=\"cpu\": the checkpoint was written from a model that may have\n",
    "# lived on the GPU; without it torch.load tries to restore the tensors onto\n",
    "# CUDA and fails on a machine where CUDA is unavailable.\n",
    "# NOTE: torch.load unpickles the file, so only load checkpoints you trust.\n",
    "model = NeuralNetwork()\n",
    "model.load_state_dict(torch.load(\"model.pth\", map_location=\"cpu\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "用加载的模型实现预测"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "预测结果是: \"Ankle boot\", 实际标签是: \"Ankle boot\"\n"
     ]
    }
   ],
   "source": [
    "# Human-readable label names, indexed by the dataset's integer class id.\n",
    "classes = [\n",
    "    \"T-shirt/top\", \"Trouser\", \"Pullover\", \"Dress\", \"Coat\",\n",
    "    \"Sandal\", \"Shirt\", \"Sneaker\", \"Bag\", \"Ankle boot\",\n",
    "]\n",
    "\n",
    "# Classify the first test sample with the reloaded model and compare to its label.\n",
    "model.eval()\n",
    "x, y = test_data[0]\n",
    "with torch.no_grad():\n",
    "    pred = model(x)\n",
    "predicted = classes[pred[0].argmax(0)]\n",
    "actual = classes[y]\n",
    "print(f'预测结果是: \"{predicted}\", 实际标签是: \"{actual}\"')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8rc1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
